{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "EgiF12Hf1Dhs"
   },
   "source": [
    "This notebook provides examples to go along with the [textbook](http://manipulation.csail.mit.edu/rl.html).  I recommend having both windows open, side-by-side!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pydrake.all import Rgba, RigidTransform, Sphere, StartMeshcat\n",
    "\n",
    "from manipulation.meshcat_utils import plot_surface\n",
    "\n",
    "# Optional imports (these are heavy dependencies for just this one notebook)\n",
    "# TODO(russt): Consume nevergrad without the (heavy) bayesian-optimization\n",
    "# dependency.\n",
    "ng_available = False\n",
    "try:\n",
    "    import nevergrad as ng\n",
    "\n",
    "    ng_available = True\n",
    "except ImportError:\n",
    "    print(\"nevergrad not found.\")\n",
    "    print(\"Consider 'pip3 install nevergrad'.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "meshcat = StartMeshcat()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let's define a few interesting cost functions that have multiple local minima."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def three_hump_camel(x, y, path=None):\n",
    "    z = (2 * x**2 - 1.05 * x**4 + x**6 / 6 + x * y + y**2) / 4\n",
    "    if path:\n",
    "        pt = f\"{path}/{x}_{y}\"\n",
    "        meshcat.SetObject(pt, Sphere(0.01), Rgba(0, 0, 1, 1))\n",
    "        meshcat.SetTransform(pt, RigidTransform([x, y, z]))\n",
    "    return z\n",
    "\n",
    "\n",
    "def plot_three_hump_camel():\n",
    "    X, Y = np.meshgrid(np.arange(-2.5, 2.5, 0.05), np.arange(-3, 3, 0.05))\n",
    "    Z = three_hump_camel(X, Y)\n",
    "    # TODO(russt): Finish the per-vertex coloring variant.\n",
    "    plot_surface(meshcat, \"three_hump_camel\", X, Y, Z, wireframe=True)\n",
    "\n",
    "\n",
    "def six_hump_camel(x, y, path=None):\n",
    "    z = (\n",
    "        x**2 * (4 - 2.1 * x**2 + x**4 / 3)\n",
    "        + x * y\n",
    "        + y**2 * (-4 + 4 * y**2)\n",
    "    )\n",
    "    if path:\n",
    "        pt = f\"{path}/{x}_{y}\"\n",
    "        meshcat.SetObject(pt, Sphere(0.01), Rgba(0, 0, 1, 1))\n",
    "        meshcat.SetTransform(pt, RigidTransform([x, y, z]))\n",
    "    return z\n",
    "\n",
    "\n",
    "def plot_six_hump_camel():\n",
    "    X, Y = np.meshgrid(np.arange(-2, 2, 0.05), np.arange(-1.2, 1.2, 0.05))\n",
    "    Z = six_hump_camel(X, Y)\n",
    "    # TODO(russt): Finish the per-vertex coloring variant.\n",
    "    plot_surface(meshcat, \"six_hump_camel\", X, Y, Z, wireframe=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Black-box optimization\n",
    "\n",
    "Let's explore a few of the algorithms from Nevergrad on these simple cost landscapes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if ng_available:\n",
    "    meshcat.Delete()\n",
    "    plot_six_hump_camel()\n",
    "\n",
    "    # Note: You can print nevergrad's available optimizers using\n",
    "    # print(sorted(ng.optimizers.registry.keys()))\n",
    "\n",
    "    # Uncomment some of these to try...\n",
    "    # solver='NGOpt'\n",
    "    # solver='RandomSearch'\n",
    "    solver = \"CMA\"\n",
    "    optimizer = ng.optimizers.registry[solver](parametrization=2, budget=100)\n",
    "    recommendation = optimizer.minimize(\n",
    "        lambda x: six_hump_camel(x[0], x[1], \"NGOpt\")\n",
    "    )\n",
    "    xstar = recommendation.value\n",
    "    meshcat.SetObject(\"recommendation\", Sphere(0.02), Rgba(0, 1, 0, 1))\n",
    "    meshcat.SetTransform(\n",
    "        \"recommendation\",\n",
    "        RigidTransform(\n",
    "            [xstar[0], xstar[1], six_hump_camel(xstar[0], xstar[1])]\n",
    "        ),\n",
    "    )\n",
    "    print(xstar)  # recommended value"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "Robotic Manipulation - Geometric Pose Estimation.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3.10.8 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "metadata": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
